{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval\n",
    "\n",
    "## Overview\n",
    "RAPTOR is an information retrieval and question-answering system that combines hierarchical document summarization, embedding-based retrieval, and contextual answer generation. It handles large document collections by building a multi-level tree of summaries, so that a query can be matched against both broad summaries and detailed source text.\n",
    "\n",
    "## Motivation\n",
    "Traditional retrieval systems often struggle with large document sets, either missing important details or getting overwhelmed by irrelevant information. RAPTOR addresses this by creating a hierarchical structure of the document collection, allowing it to navigate between high-level concepts and specific details as needed.\n",
    "\n",
    "## Key Components\n",
    "1. **Tree Building**: Creates a hierarchical structure of document summaries.\n",
    "2. **Embedding and Clustering**: Organizes documents and summaries based on semantic similarity.\n",
    "3. **Vectorstore**: Efficiently stores and retrieves document and summary embeddings.\n",
    "4. **Contextual Retriever**: Selects the most relevant information for a given query.\n",
    "5. **Answer Generation**: Produces coherent responses based on retrieved information.\n",
    "\n",
    "## Method Details\n",
    "\n",
    "### Tree Building\n",
    "1. Start with original documents at level 0.\n",
    "2. For each level:\n",
    "   - Embed the texts using a language model.\n",
    "   - Cluster the embeddings (e.g., using Gaussian Mixture Models).\n",
    "   - Generate summaries for each cluster.\n",
    "   - Use these summaries as the texts for the next level.\n",
    "3. Continue until reaching a single summary or a maximum level.\n",
    "\n",
    "### Embedding and Retrieval\n",
    "1. Embed all documents and summaries from all levels of the tree.\n",
    "2. Store these embeddings in a vectorstore (e.g., FAISS) for efficient similarity search.\n",
    "3. For a given query:\n",
    "   - Embed the query.\n",
    "   - Retrieve the most similar documents/summaries from the vectorstore.\n",
    "\n",
    "### Contextual Compression\n",
    "1. Take the retrieved documents/summaries.\n",
    "2. Use a language model to extract only the most relevant parts for the given query.\n",
    "\n",
    "### Answer Generation\n",
    "1. Combine the relevant parts into a context.\n",
    "2. Use a language model to generate an answer based on this context and the original query.\n",
    "\n",
    "## Benefits of this Approach\n",
    "1. **Scalability**: Can handle large document collections by working with summaries at different levels.\n",
    "2. **Flexibility**: Capable of providing both high-level overviews and specific details.\n",
    "3. **Context-Awareness**: Retrieves information from the most appropriate level of abstraction.\n",
    "4. **Efficiency**: Uses embeddings and vectorstore for fast retrieval.\n",
    "5. **Traceability**: Maintains links between summaries and original documents, allowing for source verification.\n",
    "\n",
    "## Conclusion\n",
    "RAPTOR combines hierarchical summarization with embedding-based retrieval and contextual answer generation to handle large document collections. Because summaries from every level of the tree are indexed alongside the original documents, a query can be answered from whichever level of abstraction matches it best, from a high-level overview down to a specific passage.\n",
    "\n",
    "Future work could focus on optimizing the tree-building process, improving summary quality, and enhancing the retrieval mechanism to better handle complex, multi-faceted queries. Integrating this approach with other AI technologies could also lead to more sophisticated information processing systems."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"text-align: center;\">\n",
    "\n",
    "<img src=\"../images/raptor.svg\" alt=\"RAPTOR\" style=\"width:100%; height:auto;\">\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports and Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from typing import List, Dict, Any\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain_community.document_loaders import PyPDFLoader\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.retrievers import ContextualCompressionRetriever\n",
    "from langchain.retrievers.document_compressors import LLMChainExtractor\n",
    "from langchain.schema import AIMessage, Document\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import logging\n",
    "import os\n",
    "import sys\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path since we work with notebooks\n",
    "from helper_functions import *\n",
    "from evaluation.evalute_rag import *\n",
    "\n",
    "# Load environment variables from a .env file\n",
    "load_dotenv()\n",
    "\n",
    "# Set the OpenAI API key environment variable (fail early with a clear message if it is missing)\n",
    "api_key = os.getenv('OPENAI_API_KEY')\n",
    "if not api_key:\n",
    "    raise ValueError(\"OPENAI_API_KEY is not set; add it to your .env file\")\n",
    "os.environ[\"OPENAI_API_KEY\"] = api_key"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define logging, LLM and embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up logging\n",
    "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
    "\n",
    "embeddings = OpenAIEmbeddings()\n",
    "llm = ChatOpenAI(model_name=\"gpt-4o-mini\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper Functions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_text(item):\n",
    "    \"\"\"Extract text content from either a string or an AIMessage object.\"\"\"\n",
    "    if isinstance(item, AIMessage):\n",
    "        return item.content\n",
    "    return item\n",
    "\n",
    "def embed_texts(texts: List[str]) -> List[List[float]]:\n",
    "    \"\"\"Embed texts using OpenAIEmbeddings.\"\"\"\n",
    "    logging.info(f\"Embedding {len(texts)} texts\")\n",
    "    return embeddings.embed_documents([extract_text(text) for text in texts])\n",
    "\n",
    "def perform_clustering(embeddings: np.ndarray, n_clusters: int = 10) -> np.ndarray:\n",
    "    \"\"\"Perform clustering on embeddings using Gaussian Mixture Model.\"\"\"\n",
    "    logging.info(f\"Performing clustering with {n_clusters} clusters\")\n",
    "    gm = GaussianMixture(n_components=n_clusters, random_state=42)\n",
    "    return gm.fit_predict(embeddings)\n",
    "\n",
    "def summarize_texts(texts: List[str]) -> str:\n",
    "    \"\"\"Summarize a list of texts using OpenAI.\"\"\"\n",
    "    logging.info(f\"Summarizing {len(texts)} texts\")\n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Summarize the following text concisely:\\n\\n{text}\"\n",
    "    )\n",
    "    chain = prompt | llm\n",
    "    # Join the texts into one string; passing the list itself would put its Python repr in the prompt\n",
    "    input_data = {\"text\": \"\\n\\n\".join(extract_text(text) for text in texts)}\n",
    "    # Return the message content so the result matches the declared str return type\n",
    "    return chain.invoke(input_data).content\n",
    "\n",
    "def visualize_clusters(embeddings: np.ndarray, labels: np.ndarray, level: int):\n",
    "    \"\"\"Visualize clusters using PCA.\"\"\"\n",
    "    from sklearn.decomposition import PCA\n",
    "    pca = PCA(n_components=2)\n",
    "    reduced_embeddings = pca.fit_transform(embeddings)\n",
    "    \n",
    "    plt.figure(figsize=(10, 8))\n",
    "    scatter = plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], c=labels, cmap='viridis')\n",
    "    plt.colorbar(scatter)\n",
    "    plt.title(f'Cluster Visualization - Level {level}')\n",
    "    plt.xlabel('First Principal Component')\n",
    "    plt.ylabel('Second Principal Component')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RAPTOR Core Function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_raptor_tree(texts: List[str], max_levels: int = 3) -> Dict[int, pd.DataFrame]:\n",
    "    \"\"\"Build the RAPTOR tree structure with level metadata and parent-child relationships.\"\"\"\n",
    "    results = {}\n",
    "    current_texts = [extract_text(text) for text in texts]\n",
    "    # Give every original document an id so that summaries can reference their children\n",
    "    current_metadata = [\n",
    "        {\"level\": 0, \"origin\": \"original\", \"parent_id\": None, \"id\": f\"doc_{i}\"}\n",
    "        for i in range(len(current_texts))\n",
    "    ]\n",
    "    \n",
    "    for level in range(1, max_levels + 1):\n",
    "        logging.info(f\"Processing level {level}\")\n",
    "        \n",
    "        # Named level_embeddings to avoid shadowing the global embeddings model\n",
    "        level_embeddings = embed_texts(current_texts)\n",
    "        # Use at least one cluster and at most 10, with roughly two texts per cluster\n",
    "        n_clusters = max(1, min(10, len(current_texts) // 2))\n",
    "        cluster_labels = perform_clustering(np.array(level_embeddings), n_clusters)\n",
    "        \n",
    "        df = pd.DataFrame({\n",
    "            'text': current_texts,\n",
    "            'embedding': level_embeddings,\n",
    "            'cluster': cluster_labels,\n",
    "            'metadata': current_metadata\n",
    "        })\n",
    "        \n",
    "        results[level-1] = df\n",
    "        \n",
    "        summaries = []\n",
    "        new_metadata = []\n",
    "        for cluster in df['cluster'].unique():\n",
    "            cluster_docs = df[df['cluster'] == cluster]\n",
    "            cluster_texts = cluster_docs['text'].tolist()\n",
    "            cluster_metadata = cluster_docs['metadata'].tolist()\n",
    "            # Store plain strings even if the summarizer returns a message object\n",
    "            summary = extract_text(summarize_texts(cluster_texts))\n",
    "            summaries.append(summary)\n",
    "            new_metadata.append({\n",
    "                \"level\": level,\n",
    "                \"origin\": f\"summary_of_cluster_{cluster}_level_{level-1}\",\n",
    "                \"child_ids\": [meta.get('id') for meta in cluster_metadata],\n",
    "                \"id\": f\"summary_{level}_{cluster}\"\n",
    "            })\n",
    "        \n",
    "        current_texts = summaries\n",
    "        current_metadata = new_metadata\n",
    "        \n",
    "        # Store the top level once the tree collapses to one summary or the level limit is reached,\n",
    "        # so the last round of summaries is not discarded\n",
    "        if len(current_texts) <= 1 or level == max_levels:\n",
    "            results[level] = pd.DataFrame({\n",
    "                'text': current_texts,\n",
    "                'embedding': embed_texts(current_texts),\n",
    "                'cluster': [0] * len(current_texts),\n",
    "                'metadata': current_metadata\n",
    "            })\n",
    "            logging.info(f\"Stopping at level {level} with {len(current_texts)} summaries\")\n",
    "            break\n",
    "    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vectorstore Function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vectorstore(tree_results: Dict[int, pd.DataFrame]) -> FAISS:\n",
    "    \"\"\"Build a FAISS vectorstore from all texts in the RAPTOR tree.\"\"\"\n",
    "    all_texts = []\n",
    "    all_embeddings = []\n",
    "    all_metadatas = []\n",
    "    \n",
    "    for level, df in tree_results.items():\n",
    "        # extract_text unwraps message objects so that only their content is indexed\n",
    "        all_texts.extend([str(extract_text(text)) for text in df['text'].tolist()])\n",
    "        all_embeddings.extend([embedding.tolist() if isinstance(embedding, np.ndarray) else embedding for embedding in df['embedding'].tolist()])\n",
    "        all_metadatas.extend(df['metadata'].tolist())\n",
    "    \n",
    "    logging.info(f\"Building vectorstore with {len(all_texts)} texts\")\n",
    "    \n",
    "    # Reuse the embeddings computed while building the tree instead of embedding every text again\n",
    "    return FAISS.from_embeddings(\n",
    "        text_embeddings=list(zip(all_texts, all_embeddings)),\n",
    "        embedding=embeddings,\n",
    "        metadatas=all_metadatas\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define tree traversal retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tree_traversal_retrieval(query: str, vectorstore: FAISS, k: int = 3) -> List[Document]:\n",
    "    \"\"\"Perform tree traversal retrieval.\"\"\"\n",
    "    query_embedding = embeddings.embed_query(query)\n",
    "    \n",
    "    def retrieve_level(level: int, parent_ids: List[str] = None) -> List[Document]:\n",
    "        # Restrict the search to this level and, below the top, to children of the selected parents\n",
    "        def level_filter(meta: Dict[str, Any]) -> bool:\n",
    "            if meta.get('level') != level:\n",
    "                return False\n",
    "            return not parent_ids or meta.get('id') in parent_ids\n",
    "        \n",
    "        # FAISS returns (Document, score) pairs; keep only the documents\n",
    "        scored_docs = vectorstore.similarity_search_with_score_by_vector(\n",
    "            query_embedding,\n",
    "            k=k,\n",
    "            filter=level_filter\n",
    "        )\n",
    "        docs = [doc for doc, _ in scored_docs]\n",
    "        \n",
    "        if not docs or level == 0:\n",
    "            return docs\n",
    "        \n",
    "        # Flatten the child IDs of every retrieved summary\n",
    "        child_ids = [cid for doc in docs for cid in doc.metadata.get('child_ids', [])]\n",
    "        \n",
    "        child_docs = retrieve_level(level - 1, child_ids)\n",
    "        return docs + child_docs\n",
    "    \n",
    "    # InMemoryDocstore keeps its documents in the private `_dict` mapping\n",
    "    max_level = max(doc.metadata['level'] for doc in vectorstore.docstore._dict.values())\n",
    "    return retrieve_level(max_level)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Retriever\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_retriever(vectorstore: FAISS) -> ContextualCompressionRetriever:\n",
    "    \"\"\"Create a retriever with contextual compression.\"\"\"\n",
    "    logging.info(\"Creating contextual compression retriever\")\n",
    "    base_retriever = vectorstore.as_retriever()\n",
    "    \n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Given the following context and question, extract only the relevant information for answering the question:\\n\\n\"\n",
    "        \"Context: {context}\\n\"\n",
    "        \"Question: {question}\\n\\n\"\n",
    "        \"Relevant Information:\"\n",
    "    )\n",
    "    \n",
    "    extractor = LLMChainExtractor.from_llm(llm, prompt=prompt)\n",
    "    \n",
    "    return ContextualCompressionRetriever(\n",
    "        base_compressor=extractor,\n",
    "        base_retriever=base_retriever\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define hierarchical retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hierarchical_retrieval(query: str, retriever: ContextualCompressionRetriever, max_level: int) -> List[Document]:\n",
    "    \"\"\"Perform hierarchical retrieval starting from the highest level, handling potential None values.\"\"\"\n",
    "    all_retrieved_docs = []\n",
    "    allowed_ids = None  # No ID restriction at the top level\n",
    "    \n",
    "    for level in range(max_level, -1, -1):\n",
    "        # Match only this level and, below the top, only children of the documents found above\n",
    "        def level_filter(meta, level=level, allowed_ids=allowed_ids):\n",
    "            if meta.get('level') != level:\n",
    "                return False\n",
    "            return allowed_ids is None or meta.get('id') in allowed_ids\n",
    "        \n",
    "        # Retrieve documents from the current level\n",
    "        level_docs = retriever.get_relevant_documents(query, filter=level_filter)\n",
    "        all_retrieved_docs.extend(level_docs)\n",
    "        \n",
    "        # Narrow the next level down to the children of the documents found here.\n",
    "        # The query text itself stays unchanged, since appended ID clauses would only distort its embedding.\n",
    "        allowed_ids = None\n",
    "        if level_docs and level > 0:\n",
    "            child_ids = [doc.metadata.get('child_ids', []) for doc in level_docs]\n",
    "            child_ids = {item for sublist in child_ids for item in sublist if item is not None}  # Flatten and filter None\n",
    "            allowed_ids = child_ids or None  # Fall back to an unrestricted search if no valid child IDs exist\n",
    "    \n",
    "    return all_retrieved_docs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RAPTOR Query Process (Online Process)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "def raptor_query(query: str, retriever: ContextualCompressionRetriever, max_level: int) -> Dict[str, Any]:\n",
    "    \"\"\"Process a query using the RAPTOR system with hierarchical retrieval.\"\"\"\n",
    "    logging.info(f\"Processing query: {query}\")\n",
    "    \n",
    "    relevant_docs = hierarchical_retrieval(query, retriever, max_level)\n",
    "    \n",
    "    doc_details = []\n",
    "    for i, doc in enumerate(relevant_docs, 1):\n",
    "        doc_details.append({\n",
    "            \"index\": i,\n",
    "            \"content\": doc.page_content,\n",
    "            \"metadata\": doc.metadata,\n",
    "            \"level\": doc.metadata.get('level', 'Unknown'),\n",
    "            \"similarity_score\": doc.metadata.get('score', 'N/A')\n",
    "        })\n",
    "    \n",
    "    context = \"\\n\\n\".join([doc.page_content for doc in relevant_docs])\n",
    "    \n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Given the following context, please answer the question:\\n\\n\"\n",
    "        \"Context: {context}\\n\\n\"\n",
    "        \"Question: {question}\\n\\n\"\n",
    "        \"Answer:\"\n",
    "    )\n",
    "    # Compose the prompt and model the same way summarize_texts does, and keep only the answer text\n",
    "    chain = prompt | llm\n",
    "    answer = chain.invoke({\"context\": context, \"question\": query}).content\n",
    "    \n",
    "    logging.info(\"Query processing completed\")\n",
    "    \n",
    "    result = {\n",
    "        \"query\": query,\n",
    "        \"retrieved_documents\": doc_details,\n",
    "        \"num_docs_retrieved\": len(relevant_docs),\n",
    "        \"context_used\": context,\n",
    "        \"answer\": answer,\n",
    "        \"model_used\": llm.model_name,\n",
    "    }\n",
    "    \n",
    "    return result\n",
    "\n",
    "\n",
    "def print_query_details(result: Dict[str, Any]):\n",
    "    \"\"\"Print detailed information about the query process, including tree level metadata.\"\"\"\n",
    "    print(f\"Query: {result['query']}\")\n",
    "    print(f\"\\nNumber of documents retrieved: {result['num_docs_retrieved']}\")\n",
    "    print(f\"\\nRetrieved Documents:\")\n",
    "    for doc in result['retrieved_documents']:\n",
    "        print(f\"  Document {doc['index']}:\")\n",
    "        print(f\"    Content: {doc['content'][:100]}...\")  # Show first 100 characters\n",
    "        print(f\"    Similarity Score: {doc['similarity_score']}\")\n",
    "        print(f\"    Tree Level: {doc['metadata'].get('level', 'Unknown')}\")\n",
    "        print(f\"    Origin: {doc['metadata'].get('origin', 'Unknown')}\")\n",
    "        # The tree stores children under 'child_ids' (see build_raptor_tree)\n",
    "        if 'child_ids' in doc['metadata']:\n",
    "            print(f\"    Number of Child Documents: {len(doc['metadata']['child_ids'])}\")\n",
    "        print()\n",
    "    \n",
    "    print(f\"\\nContext used for answer generation:\")\n",
    "    print(result['context_used'])\n",
    "    \n",
    "    print(f\"\\nGenerated Answer:\")\n",
    "    print(result['answer'])\n",
    "    \n",
    "    print(f\"\\nModel Used: {result['model_used']}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example Usage and Visualization\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define data path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../data/Understanding_Climate_Change.pdf\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Process texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PyPDFLoader(path)\n",
    "documents = loader.load()\n",
    "texts = [doc.page_content for doc in documents]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create RAPTOR components instances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the RAPTOR tree\n",
    "tree_results = build_raptor_tree(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build vectorstore\n",
    "vectorstore = build_vectorstore(tree_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create retriever\n",
    "retriever = create_retriever(vectorstore)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run a query and inspect the retrieved sources and the answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pipeline\n",
    "max_level = max(tree_results.keys())  # Highest level actually present in the tree\n",
    "query = \"What is the greenhouse effect?\"\n",
    "result = raptor_query(query, retriever, max_level)\n",
    "print_query_details(result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
